{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Tce3stUlHN0L"
},
"source": [
"##### Copyright 2024 Google LLC."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "tuOe1ymfHZPu"
},
"outputs": [],
"source": [
"# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dfsDR_omdNea"
},
"source": [
"# Gemma - Prompt Chaining and Iterative Generation\n",
"This notebook demonstrates how to use prompt chaining and iterative generation with Gemma through a story writing example.\n",
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Prompt_chaining.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FaqZItBdeokU"
},
"source": [
"## Setup\n",
"\n",
"### Select the Colab runtime\n",
"To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
"\n",
"1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
"2. Select **Change runtime type**.\n",
"3. Under **Hardware accelerator**, select **T4 GPU**.\n",
"\n",
"### Gemma setup\n",
"\n",
"To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
"\n",
"* Get access to Gemma on kaggle.com.\n",
"* Select a Colab runtime with sufficient resources to run\n",
" the Gemma 2B model.\n",
"* Generate and configure a Kaggle username and an API key as Colab secrets.\n",
"\n",
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CY2kGtsyYpHF"
},
"source": [
"### Configure your credentials\n",
"\n",
"Add your Kaggle credentials to the Colab Secrets manager to store them securely.\n",
"\n",
"1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
"2. Create new secrets: `KAGGLE_USERNAME` and `KAGGLE_KEY`\n",
"3. Copy/paste your username into `KAGGLE_USERNAME`\n",
"4. Copy/paste your key into `KAGGLE_KEY`\n",
"5. Toggle the buttons on the left to allow notebook access to the secrets.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "A9sUQ4WrP-Yr"
},
"outputs": [],
"source": [
"import os\n",
"from google.colab import userdata\n",
"\n",
"# Set the backend before importing Keras\n",
"os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
"# Avoid memory fragmentation on JAX backend.\n",
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.00\"\n",
"\n",
"# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
"# vars as appropriate for your system.\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iwjo5_Uucxkw"
},
"source": [
"### Install dependencies\n",
"Run the cell below to install all the required dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "r_nXPEsF7UWQ"
},
"outputs": [],
"source": [
"!pip install -q -U tensorflow keras keras-nlp"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pOAEiJmnBE0D"
},
"source": [
"## Prompt chaining\n",
"\n",
"Prompt chaining is a powerful technique for managing complex tasks that are difficult to accomplish in a single step.\n",
"\n",
"It entails breaking a large task into smaller, linked prompts, where each prompt's output feeds into the next. This step-by-step method steers the language model through the process. Key advantages include:\n",
"\n",
"\n",
"\n",
"* Enhanced accuracy: Focused, smaller prompts produce better results from the language model.\n",
"* Easier debugging: Pinpointing and fixing issues within the chain is straightforward, allowing for precise improvements.\n",
"* Handling complexity: Dividing intricate problems into manageable steps enables the language model to address more complex tasks.\n"
]
},
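{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell below is a minimal sketch of the chaining idea and is not used by the rest of this notebook. `run_chain` and `fake_generate` are hypothetical names, and `fake_generate` is only a stand-in for a real model call so that the cell runs before Gemma is loaded. Each template receives the previous step's output through the `{previous}` placeholder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def run_chain(templates, generate_fn, initial=\"\"):\n",
"    \"\"\"Feed each step's output into the next step's prompt template.\"\"\"\n",
"    previous = initial\n",
"    outputs = []\n",
"    for template in templates:\n",
"        prompt = template.format(previous=previous)\n",
"        previous = generate_fn(prompt)\n",
"        outputs.append(previous)\n",
"    return outputs\n",
"\n",
"\n",
"def fake_generate(prompt):\n",
"    # Stand-in for a model call (assumption): reports the prompt length.\n",
"    return f\"<reply to {len(prompt)} chars>\"\n",
"\n",
"\n",
"steps = [\n",
"    \"Write a premise about bunnies.\",\n",
"    \"Outline a story for this premise: {previous}\",\n",
"    \"Write the opening for this outline: {previous}\",\n",
"]\n",
"for step_output in run_chain(steps, fake_generate):\n",
"    print(step_output)"
]
},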
{
"cell_type": "markdown",
"metadata": {
"id": "MTDrbUbiyhHL"
},
"source": [
"## Iterative Generation\n",
"Iterative generation is the process of creating the desired output step by step. This method is particularly useful for writing stories that exceed the length limitations of a single generation window. The advantages of iterative generation include:\n",
"\n",
"* Extended outputs: It enables the production of longer and more detailed content, going beyond the constraints of a single generation window.\n",
"* Enhanced flexibility: Adjustments and refinements can be made at each\n",
"iteration, ensuring the story progresses as intended.\n",
"* Human oversight: Feedback and guidance can be provided at each step, ensuring the story stays true to the creator's vision.\n",
"\n",
"\n",
"By using **prompt chaining** and **iterative generation** together, you can create an interesting and well-structured story, adding to it piece by piece, while still having control over how it unfolds."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J3sX2mFH4GWk"
},
"source": [
"### Gemma"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fz47tAgSKMNH"
},
"source": [
"**About Gemma**\n",
"\n",
"Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.\n",
"\n",
"Here's the [official documentation](https://ai.google.dev/gemma/docs/formatting) regarding prompting instruction-tuned models."
]
},
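{
"cell_type": "markdown",
"metadata": {},
"source": [
"The instruction-tuned models expect each turn to be wrapped in control tokens, and the prompts in this notebook write those tokens out by hand. As a hedged sketch, a small helper could build the same single-turn string. `build_prompt` is a hypothetical name that the other cells don't rely on."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def build_prompt(user_text):\n",
"    \"\"\"Wrap user text in Gemma's turn markers and open the model turn.\"\"\"\n",
"    return f\"<start_of_turn>user\\n{user_text}<end_of_turn>\\n<start_of_turn>model\\n\"\n",
"\n",
"\n",
"print(build_prompt(\"Write a single-sentence premise for a children's story featuring bunnies.\"))"
]
},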
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B3WckZv2hef3"
},
"outputs": [],
"source": [
"import keras\n",
"import keras_nlp\n",
"from IPython.display import display, Markdown"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8dfseDZChhjl"
},
"outputs": [],
"source": [
"# Let's load Gemma using Keras\n",
"gemma_model_id = \"gemma2_instruct_2b_en\"\n",
"gemma = keras_nlp.models.GemmaCausalLM.from_preset(gemma_model_id)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fEhPz8PZuRcm"
},
"source": [
"## Story writing baseline: persona, premise, outline"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "u97asfO3uU-x"
},
"source": [
"As the first step, you'll create a persona for the LLM to adopt while it performs the task. In this example, you want it to act like a children's author who wants to write a new story that is both funny and educational."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i4Qg0dYqu5UL"
},
"outputs": [],
"source": [
"persona = \"\"\"You are a children's author with a penchant for humorous, yet educational stories.\n",
"Your ultimate goal is to write a new story to be published in a children's magazine.\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZhniNs61u-Ku"
},
"source": [
"The next subtask for the LLM is to develop a premise for the story, which requires a prompt. You'll build the premise prompt from the persona description. In this example, you want the model to create a story about bunnies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XHXTQdNvu7L0"
},
"outputs": [],
"source": [
"premise_prompt = f\"\"\"<start_of_turn>user\n",
"{persona}\n",
"\n",
"Write a single-sentence premise for a children's story featuring bunnies.<end_of_turn>\n",
"<start_of_turn>model\\n\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "DyOvDcBkv5s1"
},
"outputs": [
{
"data": {
"text/markdown": [
"When a mischievous bunny accidentally swaps his carrot patch with a giant sunflower field, he must learn the importance of teamwork and communication to get his carrots back. \n",
"<end_of_turn>"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"premise_response = gemma.generate(premise_prompt, max_length=512)\n",
"premise = premise_response[len(premise_prompt) :]\n",
"display(Markdown(premise))"
]
},
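{
"cell_type": "markdown",
"metadata": {},
"source": [
"The generated text above ends with the `<end_of_turn>` control token, and slicing off the prompt keeps that token. If you'd rather not carry it into later prompts, one option is to trim it first. The helper below is only a sketch: `strip_turn_markers` is a hypothetical name, the sample string is made up for illustration, and the remaining cells don't call it."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def strip_turn_markers(text):\n",
"    \"\"\"Drop Gemma turn markers and surrounding whitespace from generated text.\"\"\"\n",
"    for marker in (\"<start_of_turn>\", \"<end_of_turn>\"):\n",
"        text = text.replace(marker, \"\")\n",
"    return text.strip()\n",
"\n",
"\n",
"# Illustrative input only (assumption), shaped like the model output above.\n",
"sample = \"A bunny learns to share. \\n<end_of_turn>\"\n",
"print(strip_turn_markers(sample))"
]
},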
{
"cell_type": "markdown",
"metadata": {
"id": "IpJ55LJAvblz"
},
"source": [
"You'll use the generated premise to create a prompt for the next step, which will produce the story outline.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "c4FHQw0avkQ0"
},
"outputs": [],
"source": [
"outline_prompt = f\"\"\"<start_of_turn>user\n",
"{persona}\n",
"\n",
"You have a gripping premise in mind:\n",
"\n",
"{{premise}}\n",
"\n",
"Create an outline for your story's plot consisting of 5 key points.<end_of_turn>\n",
"<start_of_turn>model\\n\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PkT-L8i-v9xJ"
},
"outputs": [
{
"data": {
"text/markdown": [
"Here's a plot outline for a children's story about a mischievous bunny and his carrot patch swap:\n",
"\n",
"**Title:** Benny's Big Sunflower Swap\n",
"\n",
"**Key Points:**\n",
"\n",
"1. **The Mishap:** Benny, a playful bunny with a penchant for mischief, accidentally trips over a root and sends his entire carrot patch tumbling into a giant sunflower field. He's horrified! He can't imagine a world without his delicious carrots.\n",
"2. **The Sunflower Surprise:** Benny, bewildered and hungry, explores the sunflower field. He's amazed by the towering sunflowers, but they're not very tasty. He tries to eat the petals, but they're too tough. He even tries to climb the sunflowers, but they're too tall!\n",
"3. **The Sunflower Friends:** Benny meets a friendly sunflower named Sunny who explains that the sunflowers are actually quite helpful. They provide shade for the bees, attract butterflies, and even help the birds find food. Benny learns that the sunflower field is a vital part of the ecosystem.\n",
"4. **The Teamwork Challenge:** Benny realizes he needs to work with the sunflowers to get his carrots back. He learns that the sunflowers can't move the carrots, but they can help him find a way to get them back. He asks Sunny for help, and together they brainstorm ideas.\n",
"5. **The Carrot Rescue:** Benny and Sunny come up with a plan to use the sunflowers' height to reach the carrots. They work together, with Benny using his agility to climb the sunflowers and Sunny guiding him. They successfully retrieve the carrots, and Benny learns the importance of teamwork and communication.\n",
"\n",
"\n",
"**Humorous Elements:**\n",
"\n",
"* Benny's mischievous nature could lead to funny situations, like him trying to eat the sunflower seeds or trying to climb the sunflowers with his carrot patch stuck to his back.\n",
"* Sunny the sunflower could have a quirky personality, perhaps with a love for riddles or a habit of telling tall tales.\n",
"* The story could incorporate silly"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"full_outline_prompt = outline_prompt.format(premise=premise)\n",
"outline_response = gemma.generate(full_outline_prompt, max_length=512)\n",
"outline = outline_response[len(full_outline_prompt) :]\n",
"display(Markdown(outline))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9iYpDaIzv_hC"
},
"source": [
"Once you have the plan, you'd like Gemma to begin writing the story. In the prompt, include all the necessary information you've gathered so far: the persona, premise, and outline.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lWWLFF24pFOe"
},
"outputs": [],
"source": [
"starting_prompt = f\"\"\"<start_of_turn>user\n",
"{persona}\n",
"\n",
"You have a gripping premise in mind:\n",
"\n",
"{{premise}}\n",
"\n",
"Your imagination has crafted a narrative outline:\n",
"\n",
"{{outline}}\n",
"\n",
"First, silently review the outline and the premise. Consider how to start the\n",
"story.\n",
"\n",
"Your task is to write a part of the story that covers only the first point of the outline.\n",
"You are not expected to finish the whole story now.\n",
"Do not write about the next points, only the first one plot point!!!\n",
"\n",
"Try to write 10 sentences.\n",
"Remember, DO NOT WRITE A WHOLE STORY RIGHT NOW.<end_of_turn>\n",
"<start_of_turn>model\\n\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fdu7kVqVwRXQ"
},
"outputs": [
{
"data": {
"text/markdown": [
"Benny the bunny was a whirlwind of fluff and mischief. He loved nothing more than hopping through his carrot patch, nibbling on the plumpest, juiciest carrots. But today, Benny's playful antics took a turn for the disastrous. He tripped over a particularly stubborn root, sending a cascade of carrots tumbling into the air. Benny watched in horror as his entire patch, his precious treasure, disappeared into a sea of towering sunflowers. The sunflowers, with their bright yellow faces, seemed to mock him with their endless rows. Benny's heart sank. He couldn't imagine a world without his carrots! He hopped around, his ears drooping, his little nose twitching with despair. He needed his carrots, and he needed them now! But how could he possibly get them back? \n",
"<end_of_turn>"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"full_starting_prompt = starting_prompt.format(premise=premise, outline=outline)\n",
"starting_response = gemma.generate(full_starting_prompt, max_length=1000)\n",
"draft = starting_response[len(full_starting_prompt) :]\n",
"display(Markdown(draft))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MHnP5quYwYLu"
},
"source": [
"If you're pleased with the start of your story, you can continue it by further prompting the model with the text written so far. You can also add guidelines to help the model write appropriately and avoid concluding the story too quickly."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZdFyNiFQ4UZW"
},
"outputs": [],
"source": [
"guidelines = \"\"\"Writing Guidelines\n",
"\n",
"Remember, your main goal is to write as much as you can. If you get through\n",
"the story too fast, that is bad. Expand, never summarize. Don't repeat previous\n",
"parts of the story, only expand.\"\"\"\n",
"\n",
"\n",
"continuation_prompt = f\"\"\"<start_of_turn>user\n",
"{persona}\n",
"\n",
"You have the following premise in mind:\n",
"\n",
"{{premise}}\n",
"\n",
"The outline of the story looks like this:\n",
"\n",
"{{outline}}\n",
"\n",
"\n",
"Here's what you've written so far:\n",
"\n",
"{{story_text}}\n",
"\n",
"\n",
"=====\n",
"First, silently review the premise, the outline and the story you've written so far.\n",
"\n",
"Write the continuation - the next 5 sentences that cover the next outline point. Stick to the outline.\n",
"\n",
"However, once the story is COMPLETELY finished, write IAMDONE.\n",
"\n",
"{guidelines}<end_of_turn>\n",
"<start_of_turn>model\\n\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "M7VM7COm22Hz"
},
"outputs": [
{
"data": {
"text/markdown": [
"Benny hopped from sunflower to sunflower, his nose twitching, trying to catch a whiff of his lost carrots. He sniffed the air, hoping to catch a hint of their sweet scent, but all he could smell was the earthy aroma of the soil and the sweet nectar of the sunflowers. He felt a pang of sadness, realizing that he couldn't just eat the sunflowers. They were beautiful, but they weren't carrots! Benny needed a plan, and he needed it fast. \n",
"\n",
" \n",
"=====\n",
" \n",
" IAMDONE \n",
"<end_of_turn>"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"full_continuation_prompt = continuation_prompt.format(\n",
" premise=premise, outline=outline, story_text=draft\n",
")\n",
"continuation_response = gemma.generate(full_continuation_prompt, max_length=1000)\n",
"continuation = continuation_response[len(full_continuation_prompt) :]\n",
"display(Markdown(continuation))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yX8G_SlOxIB6"
},
"source": [
"Add the continuation to the initial draft, then keep building the story iteratively until the model outputs `IAMDONE`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "az0Zk9MA1XqU"
},
"outputs": [
{
"data": {
"text/markdown": [
"Benny the bunny was a whirlwind of fluff and mischief. He loved nothing more than hopping through his carrot patch, nibbling on the plumpest, juiciest carrots. But today, Benny's playful antics took a turn for the disastrous. He tripped over a particularly stubborn root, sending a cascade of carrots tumbling into the air. Benny watched in horror as his entire patch, his precious treasure, disappeared into a sea of towering sunflowers. The sunflowers, with their bright yellow faces, seemed to mock him with their endless rows. Benny's heart sank. He couldn't imagine a world without his carrots! He hopped around, his ears drooping, his little nose twitching with despair. He needed his carrots, and he needed them now! But how could he possibly get them back? \n",
"<end_of_turn>\n",
"\n",
"Benny hopped from sunflower to sunflower, his nose twitching, trying to catch a whiff of his lost carrots. He sniffed the air, hoping to catch a hint of their sweet scent, but all he could smell was the earthy aroma of the soil and the sweet nectar of the sunflowers. He felt a pang of sadness, realizing that he couldn't just eat the sunflowers. They were beautiful, but they weren't carrots! Benny needed a plan, and he needed it fast. \n",
"\n",
" \n",
"=====\n",
" \n",
" \n",
"<end_of_turn>"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"draft = draft + \"\\n\\n\" + continuation\n",
"\n",
"# Cap the number of rounds so the loop ends even if the model never writes\n",
"# IAMDONE. The value 10 is an arbitrary choice; adjust it as needed.\n",
"max_rounds = 10\n",
"rounds = 0\n",
"\n",
"while \"IAMDONE\" not in continuation and rounds < max_rounds:\n",
"    full_continuation_prompt = continuation_prompt.format(\n",
"        premise=premise, outline=outline, story_text=draft\n",
"    )\n",
"    continuation_response = gemma.generate(full_continuation_prompt, max_length=5000)\n",
"    continuation = continuation_response[len(full_continuation_prompt) :]\n",
"    draft = draft + \"\\n\\n\" + continuation\n",
"    rounds += 1\n",
"\n",
"# Remove 'IAMDONE' and display the final story\n",
"final = draft.replace(\"IAMDONE\", \"\").strip()\n",
"display(Markdown(final))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "[Gemma_2]Prompt_chaining.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}